import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.autograd as autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Seed control, for better reproducibility
# NOTE: this does not guarantee results are always the same
# (cuDNN nondeterminism and multi-worker data loading can still vary).
seed = 22
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # work on a single GPU
    device = torch.device("cuda:0")
    cudnn.benchmark = True
    # Tensor factory used throughout training so batches land on the GPU.
    Tensor = torch.cuda.FloatTensor
    # NOTE(review): the original code called generator.cuda(),
    # discriminator.cuda() and adversarial_loss.cuda() here, but those
    # objects are only created further down the file, so this branch
    # raised a NameError on CUDA machines. The models are now moved to
    # the device where they are constructed instead.
else:
    device = torch.device("cpu")
    cudnn.benchmark = False
    Tensor = torch.FloatTensor
print(device)
EPOCHS = 500
cpu
def imshow(img):
    """Display a (C, H, W) torch tensor normalized to [-1, 1] with matplotlib."""
    # Undo the (mean 0.5, std 0.5) normalization back to [0, 1],
    # then reorder the axes to H x W x C as matplotlib expects.
    unnormalized = 0.5 + img.numpy() / 2
    plt.imshow(unnormalized.transpose(1, 2, 0))
    plt.show()
batch_size = 250  # might try to use large batches (we will discuss why later when we talk about BigGAN)
# NOTE: the batch_size should be an integer divisor of the data set size or torch
# will give you an error regarding batch sizes of "0" when the data loader tries to
# load in the final batch
dataset = dset.CIFAR10(root='data/cifar/', download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           # map pixels from [0, 1] to [-1, 1] to match the
                           # generator's tanh output range
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                       ]))
# 'frog' is class index 6 in CIFAR-10 (0-based)
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# derive the label from the class list instead of hard-coding the literal 6
# (the original defined `frog = 6` but then never used it below)
frog = classes.index('frog')
frog_index = [i for i, x in enumerate(dataset.targets) if x == frog]
print("number of frog imgs: ", len(frog_index))
frog_set = torch.utils.data.Subset(dataset, frog_index)
dataloader = torch.utils.data.DataLoader(frog_set, batch_size=batch_size,
                                         shuffle=True, num_workers=1)
Files already downloaded and verified number of frog imgs: 5000
# get some random training images
dataiter = iter(dataloader)
# Use the builtin next(); the iterator's .next() method was removed in
# newer PyTorch releases and raises AttributeError there.
real_image_examples, _ = next(dataiter)
# show images
plt.figure(figsize=(10, 10))
imshow(torchvision.utils.make_grid(real_image_examples, nrow=int(np.sqrt(batch_size))))
print("Image shape: ", real_image_examples[0].size())
Image shape: torch.Size([3, 32, 32])
In this implementation of GANs, we will use a few of the tricks from F. Chollet and from Salimans et al. In particular, we will add some noise to the labels.
latent_dim = 32  # dimensionality of the generator's input noise vector z
height = 32      # CIFAR-10 image height in pixels
width = 32       # CIFAR-10 image width in pixels
channels = 3     # RGB
# Note: according to Radford (2016), is there anything done here
# that potentially could have been different?
#
class Generator(nn.Module):
    """DCGAN-style generator: latent vector z -> 32x32 RGB image in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        # First, transform the input into an 8x8 128-channel feature map.
        self.init_size = width // 4  # one quarter the image size
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        # There is no reshape layer; this is done in the forward function.
        # Alternately we could use only the functional API and bypass
        # Sequential altogether; we use the Sequential API to create blocks.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            # align_corners=False is the current default; passing it
            # explicitly keeps behavior identical while silencing the
            # UserWarning emitted by nn.Upsample with mode='bilinear'.
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # 16x16
            nn.Conv2d(128, 128, 3, padding=1),  # 16x16
            # Then, add a convolution layer
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            # NOTE(review): 0.8 here is the positional *eps* argument of
            # BatchNorm2d, not momentum -- possibly unintended, but kept
            # as-is so saved checkpoints keep their trained statistics.
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # Upsample to 32x32
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # 32x32
            nn.ConvTranspose2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # Produce a 32x32 RGB feature map squashed to [-1, 1] by tanh.
            nn.Conv2d(64, channels, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a (batch, latent_dim) noise batch to (batch, 3, 32, 32) images."""
        # Expand the sampled z to an 8x8 spatial map.
        out = self.l1(z)
        out = torch.reshape(out, (out.shape[0], 128, self.init_size, self.init_size))
        # old way for earlier Torch versions: out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
class Discriminator(nn.Module):
    """DCGAN-style discriminator: 32x32 RGB image -> probability it is real."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Four stride-2 convolutions downsample 32x32 -> 16 -> 8 -> 4 -> 2.
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # dropout layer - important, or just slowing down the optimization?
            nn.Dropout2d(0.25),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            # NOTE(review): 0.3 slope here vs 0.2 everywhere else -- looks
            # unintentional, but kept as-is to match saved checkpoints.
            nn.LeakyReLU(0.3, inplace=True),
            nn.Dropout2d(0.25),
            # NOTE(review): 0.8 is the positional *eps* argument of
            # BatchNorm2d (not momentum); possibly unintended -- confirm.
            nn.BatchNorm2d(32, 0.8),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
        )
        # The height and width of the downsampled image (32 // 16 == 2)
        ds_size = width // 2 ** 4
        # Classification layer: flattened 128 x 2 x 2 features -> one sigmoid
        # probability of "real".
        self.classification_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1),
                                                  nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of real-vs-fake probabilities."""
        out = self.model(img)
        # use the view function to flatten the layer output
        # old way for earlier Torch versions: out = out.view(out.shape[0], -1)
        out = torch.flatten(out, start_dim=1)  # don't flatten over batch size
        validity = self.classification_layer(out)
        return validity
# Custom weights initialization applied to netG and netD, adapted from
# PyTorch's official DCGAN example:
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L112
def weights_init(m):
    """Initialize conv and batch-norm layers in place (use with Module.apply).

    Conv filters: N(0, 0.02); BatchNorm scale: N(1, 0.02), bias: 0.
    Uses torch.nn.init instead of mutating .data directly, which is the
    deprecated pattern that bypasses autograd bookkeeping.
    """
    classname = m.__class__.__name__
    # Name match covers Conv2d and ConvTranspose2d alike (both contain 'Conv').
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)  # filters are zero mean, small stdev
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)  # batch norm is unit mean, small stdev
        nn.init.zeros_(m.bias)  # like normal, biases start at zero
generator = Generator()
discriminator = Discriminator()
# Move the models onto whatever device was selected above (no-op on CPU).
generator.to(device)
discriminator.to(device)
# To stabilize training, we use weight decay (L2 regularization) in the
# optimizer and gradient clipping (by value) in the training loop.
clip_value = 1.0  # applied manually during training, because PyTorch's
                  # RMSprop optimizer has no built-in clipvalue option
# set the discriminator's learning rate higher than the generator's
discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters(),
                                              lr=0.0008, weight_decay=1e-8)
gan_optimizer = torch.optim.RMSprop(generator.parameters(), lr=0.0004, weight_decay=1e-8)
# THIS LINE OF CODE DEFINES THE FUNCTION WE WILL USE AS LOSS
adversarial_loss = torch.nn.BCELoss()  # binary cross entropy
generator.apply(weights_init)
discriminator.apply(weights_init)
Discriminator(
(model): Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Dropout2d(p=0.25, inplace=False)
(3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(4): LeakyReLU(negative_slope=0.3, inplace=True)
(5): Dropout2d(p=0.25, inplace=False)
(6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(8): LeakyReLU(negative_slope=0.2, inplace=True)
(9): Dropout2d(p=0.25, inplace=False)
(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(12): LeakyReLU(negative_slope=0.2, inplace=True)
(13): Dropout2d(p=0.25, inplace=False)
(14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
)
(classification_layer): Sequential(
(0): Linear(in_features=512, out_features=1, bias=True)
(1): Sigmoid()
)
)
iterations = EPOCHS
# Sample fixed random points in the latent space so that the progress
# grids saved across checkpoints always show the same latent codes.
plot_num_examples = 25
fixed_random_latent_vectors = torch.randn(plot_num_examples, latent_dim, device=device)
img_list = []
total_steps = 0
# Grid of real examples to display next to the generated grid later.
# (The original wrapped this in np.transpose(..., (0, 1, 2)) -- an identity
# permutation; np.asarray produces the same ndarray without the no-op.)
real_image_numpy = np.asarray(torchvision.utils.make_grid(
    real_image_examples[:plot_num_examples, :, :, :], padding=2, normalize=False, nrow=5))
%%time
# TODO: optionally load the checkpoint data here from previous run
# Start training loop: each outer step is one full pass over the frog subset.
for step in range(iterations):
    total_steps = total_steps+1
    for i, (imgs, _) in enumerate(dataloader):
        #===================================
        # GENERATOR OPTIMIZE AND GET LABELS
        # Zero out any previous calculated gradients
        gan_optimizer.zero_grad()
        # Sample random points in the latent space, one z per real image
        # in this batch (imgs.shape[0] == batch size).
        random_latent_vectors = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        # Decode them to fake images, through the generator
        generated_images = generator(random_latent_vectors)
        # Assemble labels that say "all real images"
        misleading_targets = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        # Get BCE Loss function
        # E[log d(x_fake)]
        # want generator output to generate images that are "close" to all "ones"
        g_loss = adversarial_loss(discriminator(generated_images), misleading_targets)
        # now back propagate to get derivatives
        g_loss.backward()
        # use gan optimizer to only update the parameters of the generator
        # this was setup above to only use the params of generator
        gan_optimizer.step()
        #===================================
        # DISCRIMINATOR OPTIMIZE AND GET LABELS
        # Zero out any previous calculated gradients
        discriminator_optimizer.zero_grad()
        # Combine real images with some generator images
        real_images = Variable(imgs.type(Tensor))
        combined_images = torch.cat([real_images, generated_images.detach()])
        # in the above line, we "detach" the generated images from the generator
        # this is to ensure that no needless gradients are calculated
        # those parameters wouldn't be updated (because we already defined the optimized parameters)
        # but they would be calculated here, which wastes time.
        # Assemble labels discriminating real from fake images
        # (first half of the batch is real -> 1.0, second half fake -> 0.0,
        # matching the torch.cat order above).
        labels = torch.cat((
            Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False),
            Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)
        ))
        # Add small random noise to the labels - important trick!
        # Noise is in [0, 0.05), so real labels drift slightly above 1.0;
        # BCELoss still evaluates this numerically.
        # NOTE(review): torch.rand allocates on the CPU here; if Tensor is
        # torch.cuda.FloatTensor this += would raise a device mismatch --
        # confirm before running on GPU.
        labels += 0.05 * torch.rand(labels.shape)
        # Setup Discriminator loss
        # this takes the average of BCE(real images labeled as real) + BCE(fake images labeled as fake)
        # E[log d(x_real)] + E[log 1- d(x_fake)]
        # The batch_size slicing assumes every batch is exactly full, which
        # holds here because 5000 % 250 == 0 (see the note at the loader).
        d_loss = (
        adversarial_loss(discriminator(combined_images[:batch_size]), labels[:batch_size]) + \
        adversarial_loss(discriminator(combined_images[batch_size:]), labels[batch_size:])
        ) / 2
        # get gradients according to loss above
        d_loss.backward()
        # optimize the discriminator parameters to better classify images
        discriminator_optimizer.step()
        # Now Clip weights of discriminator (manually)
        # (RMSprop has no clipvalue option, so clamp parameters in place.)
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)
    #===================================
    # Occasionally save / plot (every 10 epochs)
    if step % 10 == 0:
        # Print metrics
        print('Loss at step %s: D(z_c)=%s, D(G(z_mis))=%s' % (total_steps, d_loss.item(),g_loss.item()))
        # save images in a list for display later
        with torch.no_grad():
            fake_output = generator(fixed_random_latent_vectors).detach().cpu()
        img_list.append(torchvision.utils.make_grid(fake_output, padding=2, normalize=True, nrow=5))
        # in addition, save off a checkpoint of the current models and images
        # NOTE(review): assumes models/gan_models/ already exists; each
        # checkpoint overwrites the previous files.
        ims = np.array([np.transpose(np.hstack((i,real_image_numpy)), (2,1,0)) for i in img_list])
        np.save('models/gan_models/vanilla_images.npy',ims)
        # save the state of the models (will need to recreate upon reloading)
        torch.save({'state_dict': generator.state_dict()}, 'models/gan_models/vanilla_gen.pth')
        torch.save({'state_dict': discriminator.state_dict()}, 'models/gan_models/vanilla_dis.pth')
/Users/ericlarson/opt/anaconda3/envs/mlenv2021/lib/python3.7/site-packages/torch/nn/functional.py:2494: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details. "See the documentation of nn.Upsample for details.".format(mode))
Loss at step 1: D(z_c)=0.6390590667724609, D(G(z_mis))=0.6314025521278381 Loss at step 11: D(z_c)=0.6982100605964661, D(G(z_mis))=0.6279720664024353 Loss at step 21: D(z_c)=0.7137861251831055, D(G(z_mis))=0.6510453224182129 Loss at step 31: D(z_c)=0.6936168670654297, D(G(z_mis))=0.6344101428985596 Loss at step 41: D(z_c)=0.6873511672019958, D(G(z_mis))=0.6630176901817322 Loss at step 51: D(z_c)=0.7234673500061035, D(G(z_mis))=0.6030287146568298 Loss at step 61: D(z_c)=0.6745731234550476, D(G(z_mis))=0.6478986740112305 Loss at step 71: D(z_c)=0.7010748386383057, D(G(z_mis))=0.6845511198043823 Loss at step 81: D(z_c)=0.6983008980751038, D(G(z_mis))=0.6263479590415955 Loss at step 91: D(z_c)=0.7023251056671143, D(G(z_mis))=0.6641501784324646 Loss at step 101: D(z_c)=0.689208984375, D(G(z_mis))=0.6761195063591003 Loss at step 111: D(z_c)=0.6887251734733582, D(G(z_mis))=0.6745609641075134 Loss at step 121: D(z_c)=0.6865994334220886, D(G(z_mis))=0.6533544063568115 Loss at step 131: D(z_c)=0.6784639358520508, D(G(z_mis))=0.6674836874008179 Loss at step 141: D(z_c)=0.685118556022644, D(G(z_mis))=0.6342769265174866 Loss at step 151: D(z_c)=0.6760877370834351, D(G(z_mis))=0.6201118230819702 Loss at step 161: D(z_c)=0.6607939004898071, D(G(z_mis))=0.6195535659790039 Loss at step 171: D(z_c)=0.666374683380127, D(G(z_mis))=0.7073000073432922 Loss at step 181: D(z_c)=0.6942265033721924, D(G(z_mis))=0.6349639892578125 Loss at step 191: D(z_c)=0.6862442493438721, D(G(z_mis))=0.6657617092132568 Loss at step 201: D(z_c)=0.6605865955352783, D(G(z_mis))=0.67072594165802 Loss at step 211: D(z_c)=0.7073507308959961, D(G(z_mis))=0.656679630279541 Loss at step 221: D(z_c)=0.6781940460205078, D(G(z_mis))=0.6690979599952698 Loss at step 231: D(z_c)=0.6838864088058472, D(G(z_mis))=0.6174855828285217 Loss at step 241: D(z_c)=0.6770867109298706, D(G(z_mis))=0.6095179915428162 Loss at step 251: D(z_c)=0.6648505330085754, D(G(z_mis))=0.673061192035675 Loss at step 261: D(z_c)=0.7031947374343872, 
D(G(z_mis))=0.6028718948364258 Loss at step 271: D(z_c)=0.7538137435913086, D(G(z_mis))=0.6634283661842346 Loss at step 281: D(z_c)=0.7476599216461182, D(G(z_mis))=0.6353903412818909 Loss at step 291: D(z_c)=0.6578364372253418, D(G(z_mis))=0.7342846393585205 Loss at step 301: D(z_c)=0.7089256048202515, D(G(z_mis))=0.6642106771469116 Loss at step 311: D(z_c)=0.7156796455383301, D(G(z_mis))=0.6909321546554565 Loss at step 321: D(z_c)=0.5975219011306763, D(G(z_mis))=0.5235658288002014 Loss at step 331: D(z_c)=0.6410077810287476, D(G(z_mis))=0.7566168308258057 Loss at step 341: D(z_c)=0.7963037490844727, D(G(z_mis))=0.519284725189209 Loss at step 351: D(z_c)=0.6456691026687622, D(G(z_mis))=0.7836297154426575 Loss at step 361: D(z_c)=0.5810917019844055, D(G(z_mis))=0.7167403101921082 Loss at step 371: D(z_c)=0.7073983550071716, D(G(z_mis))=0.9580529928207397 Loss at step 381: D(z_c)=0.5779390335083008, D(G(z_mis))=1.0108641386032104 Loss at step 391: D(z_c)=0.6122187376022339, D(G(z_mis))=0.8078464865684509 Loss at step 401: D(z_c)=0.5415605306625366, D(G(z_mis))=0.6531441807746887 Loss at step 411: D(z_c)=0.5214414596557617, D(G(z_mis))=0.6158226728439331 Loss at step 421: D(z_c)=0.5780767202377319, D(G(z_mis))=0.7699030637741089 Loss at step 431: D(z_c)=0.7078053951263428, D(G(z_mis))=0.6999310851097107 Loss at step 441: D(z_c)=0.457265168428421, D(G(z_mis))=0.8827989101409912 Loss at step 451: D(z_c)=0.5554391145706177, D(G(z_mis))=1.1486300230026245 Loss at step 461: D(z_c)=0.514870285987854, D(G(z_mis))=0.9928246736526489 Loss at step 471: D(z_c)=0.5272513628005981, D(G(z_mis))=0.7153290510177612 Loss at step 481: D(z_c)=0.6801974177360535, D(G(z_mis))=1.133162498474121 Loss at step 491: D(z_c)=0.42470723390579224, D(G(z_mis))=0.9347643256187439 CPU times: user 1d 6h 4min 41s, sys: 32min 7s, total: 1d 6h 36min 49s Wall time: 12h 47s
# save off everything at the end (same as the checkpoint)
import os  # local import keeps this cell self-contained
# Create the target directory if needed; np.save/torch.save raise
# FileNotFoundError when the parent directory does not exist.
os.makedirs('models/gan_models', exist_ok=True)
ims = np.array([np.transpose(np.hstack((i, real_image_numpy)), (2, 1, 0)) for i in img_list])
np.save('models/gan_models/vanilla_images.npy', ims)
# save the state of the models (will need to recreate upon reloading)
torch.save({'state_dict': generator.state_dict()}, 'models/gan_models/vanilla_gen.pth')
torch.save({'state_dict': discriminator.state_dict()}, 'models/gan_models/vanilla_dis.pth')
if True:  # load all the models
    ims = np.load('models/gan_models/vanilla_images.npy')
    # Recreate fresh model instances, then restore the saved parameters
    # into each one from its checkpoint file.
    generator = Generator()
    discriminator = Discriminator()
    for model, path in ((generator, 'models/gan_models/vanilla_gen.pth'),
                        (discriminator, 'models/gan_models/vanilla_dis.pth')):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['state_dict'])
def norm_grid(im):
    """Min-max normalize the left and right halves of an image separately.

    The input is two grids side by side (generated | real); each half is
    rescaled to [0, 1] independently so one half's value range does not
    skew the other's display contrast.

    Args:
        im: H x W x C array of pixel values.

    Returns:
        float64 ndarray of the same shape with each half scaled to [0, 1].
    """
    # np.float was removed in NumPy 1.24; the builtin float (float64) is
    # the drop-in equivalent.
    im = im.astype(float)
    rows, cols, chan = im.shape
    cols_over2 = int(cols / 2)
    tmp = im[:, :cols_over2, :]
    im[:, :cols_over2, :] = (tmp - tmp.min()) / (tmp.max() - tmp.min())
    tmp = im[:, cols_over2:, :]
    im[:, cols_over2:, :] = (tmp - tmp.min()) / (tmp.max() - tmp.min())
    return im
# Animate the saved image grids: each frame shows generated examples (left)
# next to real examples (right), normalized per half for display.
fig = plt.figure(figsize=(12, 4))
plt.axis("off")
frames = []
for im in ims:
    frames.append([plt.imshow(norm_grid(im), animated=True)])
ani = animation.ArtistAnimation(fig, frames, interval=500, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())